/*
 * Copyright (c) 2018, apexes.net. All rights reserved.
 *
 *         http://www.apexes.net
 *
 */
package net.apexes.commons.eventbus;

import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Event hub that passes each posted event through the {@link IEventFilter}s registered for its
 * address before handing it to an {@link IPublisher}.
 *
 * <p>Filters are invoked one after another, in the iteration order of the {@link PriorityList}
 * that {@link PriorityHub} supplies to {@link #doPost}. Each filter receives its own
 * {@link EventFilterChain} link; calling {@code next(event)} on it forwards the (possibly
 * replaced) event to the next filter, and once no filters remain the event is delivered with
 * {@code publisher.publish(address, event)}. A filter that never calls {@code next} stops the
 * event from being published. Each link may be used only once.
 *
 * @author <a href="mailto:hedyn@foxmail.com">HeDYn</a>
 *
 */
public class EventBusFilter extends PriorityHub {
    
    /** Receives events that made it through every filter. */
    private final IPublisher publisher;
    
    /**
     * Creates a filter hub that forwards accepted events to {@code publisher}.
     *
     * @param publisher target for events that pass all filters; checked with {@code verifyNotNull}
     */
    public EventBusFilter(IPublisher publisher) {
        super(IEventFilter.class);
        verifyNotNull(publisher, "publisher");
        this.publisher = publisher;
    }
    
    /**
     * Registers {@code filter} for events posted to {@code address}.
     *
     * @param address the address the filter applies to; checked with {@code verifyNotNull}
     * @param filter the filter to add; checked with {@code verifyNotNull}
     * @param <T> the event type handled by the filter
     */
    public <T> void register(final String address, final IEventFilter<T> filter) {
        verifyNotNull(address, "address");
        verifyNotNull(filter, "filter");
        super.register(address, filter);
    }
    
    /**
     * Removes a previously registered {@code filter} from {@code address}.
     *
     * @param address the address the filter was registered under; checked with {@code verifyNotNull}
     * @param filter the filter to remove; checked with {@code verifyNotNull}
     * @param <T> the event type handled by the filter
     */
    public <T> void unregister(String address, IEventFilter<T> filter) {
        verifyNotNull(address, "address");
        verifyNotNull(filter, "filter");
        super.unregister(address, filter);
    }
    
    /**
     * Collects the methods of {@code object} annotated with {@link EventFilter} and wraps each one
     * in an {@link AnnotationEventFilter}.
     *
     * <p>Only methods declared directly by the object's class are inspected
     * ({@link Class#getDeclaredMethods()}); inherited methods are not. An annotated method must take
     * exactly two parameters, the second of which must be assignable to {@link EventFilterChain}.
     *
     * @param object the object to scan
     * @return the filters found, in the order {@code getDeclaredMethods()} returned them
     *         (empty if none are annotated)
     * @throws IllegalArgumentException if an annotated method has an unsupported signature
     */
    @Override
    protected Set<AnnotationPriority> getAnnotationPrioritys(Object object) {
        Set<AnnotationPriority> filters = new LinkedHashSet<>();
        Class<?> clazz = object.getClass();
        for (Method method : clazz.getDeclaredMethods()) {
            if (!method.isAnnotationPresent(EventFilter.class)) {
                continue;
            }
            Class<?>[] parameterTypes = method.getParameterTypes();
            // Expected shape: (event, EventFilterChain)
            if (parameterTypes.length != 2) {
                throw new IllegalArgumentException("The method is invalid. method=" + method);
            }
            if (!EventFilterChain.class.isAssignableFrom(parameterTypes[1])) {
                throw new IllegalArgumentException("The method is invalid. method=" + method);
            }
            filters.add(new AnnotationEventFilter(object, method));
        }
        return filters;
    }
    
    /**
     * Runs {@code event} through the filters in {@code prioritys}. If every filter forwards it (or
     * there are no filters), it is published to {@code address}.
     */
    @Override
    protected <E> void doPost(String address, E event, PriorityList<E> prioritys) {
        EventFilterChain<E> chain = new EventFilterChainImpl<>(publisher, address, prioritys.iterator());
        chain.next(event);
    }
    
    /**
     * One link of a filter chain.
     *
     * <p>Every filter receives a fresh link that shares the underlying iterator. Each link accepts
     * a single {@link #next} call, so the shared iterator advances exactly once per position. This
     * rejects a filter that forwards the same event twice, and it does not confuse different
     * filters that happen to share a class or method.
     *
     * @author <a href="mailto:hedyn@foxmail.com">HeDYn</a>
     *
     * @param <E> the event type
     */
    private static final class EventFilterChainImpl<E> implements EventFilterChain<E> {
        
        private final IPublisher publisher;
        private final String address;
        /** Remaining filters; shared by all links of one chain. */
        private final Iterator<IPriority<E>> iterator;
        /** Set once {@link #next} has been called on this link (not synchronized). */
        private boolean invoked;
        
        EventFilterChainImpl(IPublisher publisher, String address, Iterator<IPriority<E>> iterator) {
            this.publisher = publisher;
            this.address = address;
            this.iterator = iterator;
        }

        /**
         * Passes {@code event} to the next filter, or publishes it when no filters remain.
         *
         * @throws IllegalStateException if {@code next} was already called on this link
         */
        @Override
        public void next(E event) {
            if (invoked) {
                throw new IllegalStateException(
                        "EventFilterChain.next() may only be called once per filter; address=" + address);
            }
            invoked = true;
            if (!iterator.hasNext()) {
                // Every filter forwarded the event: deliver it.
                publisher.publish(address, event);
                return;
            }
            // NOTE(review): presumably safe because this hub is created with IEventFilter.class
            // (see the constructor), so its registered priorities should be IEventFilter instances.
            @SuppressWarnings("unchecked")
            IEventFilter<E> filter = (IEventFilter<E>) iterator.next();
            filter.filter(event, new EventFilterChainImpl<>(publisher, address, iterator));
        }
        
    }

}
